from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch.functional import Tensor

class Loss(object):
    """Base class for losses over a flat batch of positive and negative scores.

    Scores are laid out along dim 0 with all positive samples first, followed
    by the negatives. With ``neg_size`` negatives per positive, a full batch
    satisfies ``batch_score.shape[0] == batch_size * (1 + neg_size)``.
    """

    def __init__(self, neg_size):
        # Number of negative samples per positive sample.
        self.neg_size = neg_size

    def split(self, batch_score: Tensor) -> Tuple[Tensor, Tensor]:
        """Split ``batch_score`` along dim 0 into positive and negative scores.

        Returns:
            ``(p_score, n_score)``: ``p_score`` is the first ``batch_size``
            entries and ``n_score`` is everything after them.
        """
        batch_size = self._compute_batch_size(batch_score)
        p_score = batch_score[:batch_size]
        n_score = batch_score[batch_size:]
        return p_score, n_score

    def compute(
        self, batch_score: Tensor, subsampling_weight: Optional[Tensor] = None
    ) -> Tensor:
        """Return the loss for ``batch_score``; must be implemented by subclasses.

        ``subsampling_weight`` is optional so that subclasses overriding this
        as ``compute(batch_score)`` keep a call signature compatible with the
        base class.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError()

    def _compute_batch_size(self, batch_score: Tensor) -> int:
        """Return the number of positive samples contained in ``batch_score``.

        This is derived from the actual tensor length, not a configured batch
        size, because the dataloader may return a smaller batch (usually the
        last one of an epoch). Alternative: ``drop_last=True`` on the
        PyTorch DataLoader.
        """
        return batch_score.shape[0] // (1 + self.neg_size)

class CELoss(Loss):
    """Cross-entropy loss: each positive is the target class among its negatives.

    References:
    - KLDivLoss-based formulation:
      https://github.com/uma-pi1/kge/blob/db908a99df5efe20f960dc3cf57eb57206c2f36c/kge/util/loss.py#L192
    - Cross entropy with one-hot targets:
      https://discuss.pytorch.org/t/cross-entropy-with-one-hot-targets/13580
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Not used by ``compute``; kept because callers may rely on the attribute.
        # NOTE: with reduction='mean', KLDivLoss divides by the total number of
        # elements (batch_size * (1 + neg_size)), not by the batch size, so it
        # does NOT reproduce ``compute``; reduction='batchmean' would.
        self.loss_fn = torch.nn.KLDivLoss(reduction='mean')

    def compute(self, batch_score: Tensor) -> Tensor:
        """Return the mean negative log-softmax probability of the positives.

        ``batch_score`` is flattened and regrouped into a
        ``(batch_size, 1 + neg_size)`` matrix whose column 0 holds the
        positive scores. Each row is log-softmax-normalised, and the loss is
        ``-mean(log_softmax(rows)[:, 0])``.

        This equals ``KLDivLoss(reduction='batchmean')`` applied to
        ``F.log_softmax(rows, dim=-1)`` with one-hot targets on column 0.
        """
        pos_samples_size = self._compute_batch_size(batch_score)

        # reshape(-1, n).t() makes row j = (batch_score[j], batch_score[n + j],
        # batch_score[2n + j], ...): positive j followed by every n-th entry.
        # NOTE(review): that pairs pos_j with its own negatives only when the
        # negatives are stored negative-index-major
        # (neg_1_1, neg_2_1, ..., neg_n_1, neg_1_2, ...), not grouped per
        # positive (neg_1_1, ..., neg_1_x, neg_2_1, ...). Confirm against the
        # negative sampler.
        scores = batch_score.reshape(-1, pos_samples_size).t()
        log_probs = F.log_softmax(scores, dim=-1)

        # Only the positive's log-probability (column 0) is relevant. Use the
        # tensor reduction instead of builtin sum(), which iterates the tensor
        # element by element in Python and builds a per-element autograd chain.
        return -log_probs[:, 0].mean()


class MAELoss(Loss):
    """Mean absolute error between two score columns (used for attributes)."""

    def compute(self, batch_score: Tensor) -> Tensor:
        """Return ``mean(|batch_score[..., 0] - batch_score[..., 1]|)``.

        Only indices 0 and 1 of the last dimension are read. Presumably these
        are prediction and target (TODO confirm against the caller); the
        absolute difference is symmetric, so their order does not matter.
        """
        left = batch_score[..., 0]
        right = batch_score[..., 1]
        return (left - right).abs().mean()